package com.tca.common.classload;

import com.tca.common.core.utils.ValidateUtils;
import lombok.extern.slf4j.Slf4j;

import java.io.*;
import java.nio.file.Files;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Class loader that loads {@code .class} files from a fixed directory on the local file system.
 *
 * <p>A class with binary name {@code a.b.C} is looked up as {@code <directory>/a/b/C.class}.
 * Only {@link #findClass(String)} is overridden, so the normal parent-first delegation of
 * {@link ClassLoader#loadClass(String)} still applies; because no explicit parent is passed to
 * the super constructor, the parent is the system class loader.
 *
 * <p>Instances are obtained through {@link #getClassLoader(String)} or
 * {@link #getClassLoader(String, boolean)}. Singleton instances are cached per directory string
 * for the lifetime of this class and are never evicted.
 *
 * @author zhoua
 * @date 2023/12/4 21:57
 */
@Slf4j
public class ThirdClazzClassLoader extends ClassLoader {

    /**
     * Cache of singleton loaders, keyed by the exact directory string supplied by the caller
     * (so {@code "/x"} and {@code "/x/"} map to different loaders). Entries are never removed.
     */
    private static final Map<String, ThirdClazzClassLoader> CLASS_LOADER_MAP = new ConcurrentHashMap<>();

    /**
     * Root directory that class files are resolved against.
     */
    private final String directory;

    private ThirdClazzClassLoader(String directory) {
        // Fail fast here instead of with an NPE inside findClass later on.
        this.directory = Objects.requireNonNull(directory, "directory");
    }

    /**
     * Returns a class loader that reads class files from {@code directory}.
     *
     * @param directory   root directory of the class files; must not be {@code null}
     * @param isSingleton {@code true} to return the cached loader for this directory (creating
     *                    and caching it on first use); {@code false} to always create a new,
     *                    uncached loader
     * @return the class loader
     * @throws NullPointerException if {@code directory} is {@code null}
     * @throws Exception            kept in the signature for backward compatibility; no checked
     *                              exception is actually thrown
     */
    public static ClassLoader getClassLoader(String directory, boolean isSingleton) throws Exception {
        if (!isSingleton) {
            return new ThirdClazzClassLoader(directory);
        }
        // computeIfAbsent is atomic on ConcurrentHashMap, so concurrent callers for the same
        // directory always receive the same instance (a separate get-then-put could create two).
        return CLASS_LOADER_MAP.computeIfAbsent(directory, ThirdClazzClassLoader::new);
    }

    /**
     * Returns the cached (singleton) class loader for {@code directory}, creating it on first use.
     *
     * @param directory root directory of the class files; must not be {@code null}
     * @return the class loader
     * @throws NullPointerException if {@code directory} is {@code null}
     * @throws Exception            kept in the signature for backward compatibility; no checked
     *                              exception is actually thrown
     */
    public static ClassLoader getClassLoader(String directory) throws Exception {
        return getClassLoader(directory, true);
    }

    /**
     * Reads {@code <directory>/<name with '.' replaced by the path separator>.class} and defines it.
     *
     * @param name binary name of the class, e.g. {@code com.example.Foo}
     * @return the defined class
     * @throws ClassNotFoundException if the class file does not exist, or cannot be read or
     *                                defined; the message is the class name and any underlying
     *                                failure is attached as the cause
     */
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        // Convert the binary class name into a path below the root directory.
        String fileName = this.directory
                + (this.directory.endsWith(File.separator) ? "" : File.separator)
                + name.replace('.', File.separatorChar) + ".class";
        File file = new File(fileName);
        // isFile rather than exists: a directory that happens to be named "X.class" is not a class file.
        if (!file.isFile()) {
            log.error("class file not found, file = {}", fileName);
            throw new ClassNotFoundException(name);
        }
        try {
            // readAllBytes opens and closes the stream itself, so no handle can leak.
            byte[] bytes = Files.readAllBytes(file.toPath());
            return defineClass(name, bytes, 0, bytes.length);
        } catch (IOException | RuntimeException e) {
            // loadClass callers expect ClassNotFoundException; keep the real failure as the cause.
            log.error("加载文件出错", e);
            throw new ClassNotFoundException(name, e);
        }
    }
}
